{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "原文代码作者：https://github.com/wzyonggege/statistical-learning-method\n",
    "\n",
    "中文注释制作：机器学习初学者(微信公众号：ID:ai-start-com)\n",
    "\n",
    "配置环境：python 3.6\n",
    "\n",
    "代码全部测试通过。\n",
    "![gongzhong](../gongzhong.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第9章 EM算法及其推广"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Expectation Maximization algorithm\n",
    "\n",
    "### Maximum likehood function\n",
    "\n",
    "[likehood & maximum likehood](http://fangs.in/post/thinkstats/likelihood/)\n",
    "\n",
    "> 在统计学中，似然函数（likelihood function，通常简写为likelihood，似然）是一个非常重要的内容，在非正式场合似然和概率（Probability）几乎是一对同义词，但是在统计学中似然和概率却是两个不同的概念。概率是在特定环境下某件事情发生的可能性，也就是结果没有产生之前依据环境所对应的参数来预测某件事情发生的可能性，比如抛硬币，抛之前我们不知道最后是哪一面朝上，但是根据硬币的性质我们可以推测任何一面朝上的可能性均为50%，这个概率只有在抛硬币之前才是有意义的，抛完硬币后的结果便是确定的；而似然刚好相反，是在确定的结果下去推测产生这个结果的可能环境（参数），还是抛硬币的例子，假设我们随机抛掷一枚硬币1,000次，结果500次人头朝上，500次数字朝上（实际情况一般不会这么理想，这里只是举个例子），我们很容易判断这是一枚标准的硬币，两面朝上的概率均为50%，这个过程就是我们运用出现的结果来判断这个事情本身的性质（参数），也就是似然。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$P(Y|\\theta) = \\prod[\\pi p^{y_i}(1-p)^{1-y_i}+(1-\\pi) q^{y_i}(1-q)^{1-y_i}]$$\n",
    "\n",
    "### E step:\n",
    "\n",
    "$$\\mu^{i+1}=\\frac{\\pi (p^i)^{y_i}(1-(p^i))^{1-y_i}}{\\pi (p^i)^{y_i}(1-(p^i))^{1-y_i}+(1-\\pi) (q^i)^{y_i}(1-(q^i))^{1-y_i}}$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "pro_A, pro_B, por_C = 0.5, 0.5, 0.5\n",
    "\n",
    "def pmf(i, pro_A, pro_B, por_C):\n",
    "    pro_1 = pro_A * math.pow(pro_B, data[i]) * math.pow((1-pro_B), 1-data[i])\n",
    "    pro_2 = pro_A * math.pow(pro_C, data[i]) * math.pow((1-pro_C), 1-data[i])\n",
    "    return pro_1 / (pro_1 + pro_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### M step:\n",
    "\n",
    "$$\\pi^{i+1}=\\frac{1}{n}\\sum_{j=1}^n\\mu^{i+1}_j$$\n",
    "\n",
    "$$p^{i+1}=\\frac{\\sum_{j=1}^n\\mu^{i+1}_jy_i}{\\sum_{j=1}^n\\mu^{i+1}_j}$$\n",
    "\n",
    "$$q^{i+1}=\\frac{\\sum_{j=1}^n(1-\\mu^{i+1}_jy_i)}{\\sum_{j=1}^n(1-\\mu^{i+1}_j)}$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EM:\n",
    "    def __init__(self, prob):\n",
    "        self.pro_A, self.pro_B, self.pro_C = prob\n",
    "        \n",
    "    # e_step\n",
    "    def pmf(self, i):\n",
    "        pro_1 = self.pro_A * math.pow(self.pro_B, data[i]) * math.pow((1-self.pro_B), 1-data[i])\n",
    "        pro_2 = (1 - self.pro_A) * math.pow(self.pro_C, data[i]) * math.pow((1-self.pro_C), 1-data[i])\n",
    "        return pro_1 / (pro_1 + pro_2)\n",
    "    \n",
    "    # m_step\n",
    "    def fit(self, data):\n",
    "        count = len(data)\n",
    "        print('init prob:{}, {}, {}'.format(self.pro_A, self.pro_B, self.pro_C))\n",
    "        for d in range(count):\n",
    "            _ = yield\n",
    "            _pmf = [self.pmf(k) for k in range(count)]\n",
    "            pro_A = 1/ count * sum(_pmf)\n",
    "            pro_B = sum([_pmf[k]*data[k] for k in range(count)]) / sum([_pmf[k] for k in range(count)])\n",
    "            pro_C = sum([(1-_pmf[k])*data[k] for k in range(count)]) / sum([(1-_pmf[k]) for k in range(count)])\n",
    "            print('{}/{}  pro_a:{:.3f}, pro_b:{:.3f}, pro_c:{:.3f}'.format(d+1, count, pro_A, pro_B, pro_C))\n",
    "            self.pro_A = pro_A\n",
    "            self.pro_B = pro_B\n",
    "            self.pro_C = pro_C\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "data=[1,1,0,1,0,0,1,0,1,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "init prob:0.5, 0.5, 0.5\n"
     ]
    }
   ],
   "source": [
    "em = EM(prob=[0.5, 0.5, 0.5])\n",
    "f = em.fit(data)\n",
    "next(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/10  pro_a:0.500, pro_b:0.600, pro_c:0.600\n"
     ]
    }
   ],
   "source": [
    "# 第一次迭代\n",
    "f.send(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2/10  pro_a:0.500, pro_b:0.600, pro_c:0.600\n"
     ]
    }
   ],
   "source": [
    "# 第二次\n",
    "f.send(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "init prob:0.4, 0.6, 0.7\n"
     ]
    }
   ],
   "source": [
    "em = EM(prob=[0.4, 0.6, 0.7])\n",
    "f2 = em.fit(data)\n",
    "next(f2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/10  pro_a:0.406, pro_b:0.537, pro_c:0.643\n"
     ]
    }
   ],
   "source": [
    "f2.send(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2/10  pro_a:0.406, pro_b:0.537, pro_c:0.643\n"
     ]
    }
   ],
   "source": [
    "f2.send(2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
